Add ListMLE Loss #130
Conversation
Not sure why tests did not run, commenting so that tests run.
abheesht17 left a comment:
Sorry for the delay, did one pass! Some questions below, let me know what you think.
I'll take a closer look again!
keras_rs/api/layers/__init__.py
Outdated
from keras_rs.src.layers.embedding.distributed_embedding import DistributedEmbedding as DistributedEmbedding
from keras_rs.src.layers.embedding.distributed_embedding_config import FeatureConfig as FeatureConfig
from keras_rs.src.layers.embedding.distributed_embedding_config import TableConfig as TableConfig
from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce as EmbedReduce
from keras_rs.src.layers.feature_interaction.dot_interaction import DotInteraction as DotInteraction
from keras_rs.src.layers.feature_interaction.feature_cross import FeatureCross as FeatureCross
from keras_rs.src.layers.retrieval.brute_force_retrieval import BruteForceRetrieval as BruteForceRetrieval
from keras_rs.src.layers.retrieval.hard_negative_mining import HardNegativeMining as HardNegativeMining
from keras_rs.src.layers.retrieval.remove_accidental_hits import RemoveAccidentalHits as RemoveAccidentalHits
Hmmm, these shouldn't be getting re-formatted, in my opinion. Are you on the correct Ruff version? Same for other files
Force-pushed from ef5ee24 to 63388e9
Hi @abheesht17, I've made the suggested changes. Could you please review the PR? Thank you!
abheesht17 left a comment:
This LGTM! I'm going to drop an approval on this, but could you please format the code first?
Not sure why we don't have a way to run tests here, need to figure that out before merging.
abheesht17 left a comment:
Noticed a few nits. Instead of using `+`, `-`, `*`, `/`, let's use Keras ops everywhere.
@LakshmiKalaKadali - can you look into the test failures above?
Hi @abheesht17, the torch test failures are due to a precision mismatch between actual and expected values. That mismatch might be caused by applying the mask AFTER exp(). I applied the mask to set invalid positions to -1e9 before calling ops.exp(), and removed the masking that happened after exp(). I hope this resolves the torch test failure. Could you please run the tests? Thank you!
The version of torch was updated for the tests. Please rebase to see if it helps.
Force-pushed from e986ef5 to 3da3fa2
I do not understand why the results are different, but the differences you see are significant enough that it has to be a bug. Can you try to debug by printing all intermediate values?
Sure @hertschuh, I will deep-dive into the intermediate values. Thank you!
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Force-pushed from 157f774 to 8cb2d98
Hi @hertschuh, all the tests are passing now. Could you please take a look? Thank you!
hertschuh left a comment:
Thanks for figuring out the issue!
.github/workflows/actions.yml
Outdated
- name: Test with pytest
  run: |
    pytest keras_rs/
    pytest keras_rs/
Undo this file.
keras_rs/src/losses/list_mle_loss.py
Outdated
# added stable offset before calling sort_by_scores
list_size = ops.shape(labels_for_sorting)[1]
indices = ops.arange(list_size)
indices = ops.expand_dims(indices, axis=0)
indices = ops.broadcast_to(indices, ops.shape(labels_for_sorting))
stable_offset = ops.cast(indices, labels_for_sorting.dtype) * 1e-6
labels_for_sorting = ops.subtract(labels_for_sorting, stable_offset)
Oh... so the issue is that Torch's topk is not stable with sorted=True, unlike JAX and TF.
There is no easy fix for topk that would be efficient.
However, do you mind:
- moving this into `sort_by_scores` in `ranking_metrics_utils.py`
- doing it only if `keras.backend.backend() == "torch"` and `shuffle_ties` is False

Thanks!
Can you add a comment with a reference to these:
- Make topk sort stable (pytorch/pytorch#27542)
- Enable `torch.topk` to support `stable` flag (pytorch/pytorch#88227)

We will be able to remove this code if they ever fix the "stable" issue.
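The workaround can be sketched standalone in numpy (the helper name is hypothetical; the PR's version uses keras.ops inside the loss/metric utilities): subtracting a tiny, index-proportional offset makes tied scores strictly decreasing, so any sort, stable or not, orders ties by their original position.

```python
import numpy as np

def stable_descending_order(scores):
    # Hypothetical standalone version of the tie-breaking trick: subtract
    # index * 1e-6 so tied scores become strictly decreasing, making the
    # ordering deterministic even under an unstable sort/topk.
    scores = np.asarray(scores, dtype=np.float64)
    offsets = np.arange(scores.shape[-1]) * 1e-6
    return np.argsort(-(scores - offsets), axis=-1)

order = stable_descending_order([1.0, 3.0, 3.0, 2.0])
# The 3.0 at index 1 now always ranks ahead of the tied 3.0 at index 2.
```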
Force-pushed from 0761a46 to 22757fd
# equal scores. We can remove this workaround once PyTorch adds a
# `stable=True` flag for topk.

if K.backend() == "torch" and not shuffle_ties:
Nitpick: `K` was how keras was imported many, many years ago, but we don't use that style anymore. Simply remove the import and write `keras.backend.backend()`.
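A minimal sketch of the style the comment asks for, assuming Keras 3 is installed:

```python
import keras

# Preferred modern style: call the public helper directly instead of the
# legacy `from keras import backend as K` alias.
if keras.backend.backend() == "torch":
    print("running on the torch backend")
```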
hertschuh left a comment:
Thanks!
Added ListMLELoss code for listwise ranking. This implementation does not consider lambda weights.
Here is the gist with results verified against the TFRS ListMLELoss.